//===--- Evaluator.cpp - Request Evaluator Implementation -----------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements the Evaluator class that evaluates and caches
// requests.
//
//===----------------------------------------------------------------------===//
#include "polarphp/ast/Evaluator.h"
#include "polarphp/ast/DiagnosticEngine.h"
#include "polarphp/basic/Range.h"
#include "polarphp/basic/SourceMgr.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h"

namespace polar {

AnyRequest::HolderBase::~HolderBase() { }

std::string AnyRequest::getAsString() const {
   std::string result;
   {
      llvm::raw_string_ostream out(result);
      simple_display(out, *this);
   }
   return result;
}

AbstractRequestFunction *
Evaluator::getAbstractRequestFunction(uint8_t zoneID, uint8_t requestID) const {
   for (const auto &zone : requestFunctionsByZone) {
      if (zone.first == zoneID) {
         if (requestID < zone.second.size())
            return zone.second[requestID];

         return nullptr;
      }
   }

   return nullptr;
}

void Evaluator::registerRequestFunctions(
   Zone zone,
   ArrayRef<AbstractRequestFunction *> functions) {
   uint8_t zoneID = static_cast<uint8_t>(zone);
#ifndef NDEBUG
   for (const auto &zone : requestFunctionsByZone) {
      assert(zone.first != zoneID);
   }
#endif

   requestFunctionsByZone.push_back({zoneID, functions});
}

Evaluator::Evaluator(DiagnosticEngine &diags, bool debugDumpCycles)
   : diags(diags), debugDumpCycles(debugDumpCycles) { }

void Evaluator::emitRequestEvaluatorGraphViz(llvm::StringRef graphVizPath) {
   std::error_code error;
   llvm::raw_fd_ostream out(graphVizPath, error, llvm::sys::fs::F_Text);
   printDependenciesGraphviz(out);
}

bool Evaluator::checkDependency(const AnyRequest &request) {
   // If there is an active request, record it's dependency on this request.
   if (!activeRequests.empty())
      dependencies[activeRequests.back()].push_back(request);

   // Record this as an active request.
   if (activeRequests.insert(request)) {
      return false;
   }

   // Diagnose cycle.
   diagnoseCycle(request);

   if (debugDumpCycles) {
      llvm::errs() << "===CYCLE DETECTED===\n";
      llvm::DenseSet<AnyRequest> visitedAnywhere;
      llvm::SmallVector<AnyRequest, 4> visitedAlongPath;
      std::string prefixStr;
      printDependencies(activeRequests.front(), llvm::errs(), visitedAnywhere,
                        visitedAlongPath, activeRequests.getArrayRef(),
                        prefixStr, /*lastChild=*/true);
   }

   return true;
}

void Evaluator::diagnoseCycle(const AnyRequest &request) {
   request.diagnoseCycle(diags);
   for (const auto &step : llvm::reverse(activeRequests)) {
      if (step == request) return;

      step.noteCycleStep(diags);
   }

   llvm_unreachable("Diagnosed a cycle but it wasn't represented in the stack");
}

void Evaluator::printDependencies(
   const AnyRequest &request,
   llvm::raw_ostream &out,
   llvm::DenseSet<AnyRequest> &visitedAnywhere,
   llvm::SmallVectorImpl<AnyRequest> &visitedAlongPath,
   llvm::ArrayRef<AnyRequest> highlightPath,
   std::string &prefixStr,
   bool lastChild) const {
   out << prefixStr << " `--";

   // Determine whether this node should be highlighted.
   bool isHighlighted = false;
   if (std::find(highlightPath.begin(), highlightPath.end(), request)
       != highlightPath.end()) {
      isHighlighted = true;
      out.changeColor(llvm::buffer_ostream::Colors::GREEN);
   }

   // Print this node.
   simple_display(out, request);

   // Turn off the highlight.
   if (isHighlighted) {
      out.resetColor();
   }

   // Print the cached value, if known.
   auto cachedValue = cache.find(request);
   if (cachedValue != cache.end()) {
      out << " -> ";
      printEscapedString(cachedValue->second.getAsString(), out);
   }

   if (!visitedAnywhere.insert(request).second) {
      // We've already seed this node. Check whether it's part of a cycle.
      if (std::find(visitedAlongPath.begin(), visitedAlongPath.end(), request)
          != visitedAlongPath.end()) {
         // We have a cyclic dependency.
         out.changeColor(llvm::raw_ostream::RED);
         out << " (cyclic dependency)\n";
      } else {
         // We have seen this node before, but it's not a cycle. Elide its
         // children.
         out << " (elided)\n";
      }

      out.resetColor();
   } else if (dependencies.count(request) == 0) {
      // We have not seen this node before, so we don't know its dependencies.
      out.changeColor(llvm::raw_ostream::GREEN);
      out << " (dependency not evaluated)\n";
      out.resetColor();

      // Remove from the visited set.
      visitedAnywhere.erase(request);
   } else {
      // Print children.
      out << "\n";

      // Setup the prefix to print the children.
      prefixStr += ' ';
      prefixStr += (lastChild ? ' ' : '|');
      prefixStr += "  ";

      // Note that this request is along the path.
      visitedAlongPath.push_back(request);

      // Print the children.
      auto &dependsOn = dependencies.find(request)->second;
      for (unsigned i : indices(dependsOn)) {
         printDependencies(dependsOn[i], out, visitedAnywhere, visitedAlongPath,
                           highlightPath, prefixStr, i == dependsOn.size()-1);
      }

      // Drop our changes to the prefix.
      prefixStr.erase(prefixStr.end() - 4, prefixStr.end());

      // Remove from the visited set and path.
      visitedAnywhere.erase(request);
      assert(visitedAlongPath.back() == request);
      visitedAlongPath.pop_back();
   }
}

void Evaluator::dumpDependencies(const AnyRequest &request) const {
   printDependencies(request, llvm::dbgs());
}

void Evaluator::printDependenciesGraphviz(llvm::raw_ostream &out) const {
   // Form a list of all of the requests we know about.
   std::vector<AnyRequest> allRequests;
   for (const auto &knownRequest : dependencies) {
      allRequests.push_back(knownRequest.first);
   }

   // Sort the list of requests based on the display strings, so we get
   // deterministic output.
   auto getDisplayString = [&](const AnyRequest &request) {
      std::string result;
      {
         llvm::raw_string_ostream out(result);
         simple_display(out, request);
      }
      return result;
   };
   std::sort(allRequests.begin(), allRequests.end(),
             [&](const AnyRequest &lhs, const AnyRequest &rhs) {
                return getDisplayString(lhs) < getDisplayString(rhs);
             });

   // Manage request IDs to use in the resulting output graph.
   llvm::DenseMap<AnyRequest, unsigned> requestIDs;
   unsigned nextID = 0;

   // Prepopulate the known requests.
   for (const auto &request : allRequests) {
      requestIDs[request] = nextID++;
   }

   auto getRequestID = [&](const AnyRequest &request) {
      auto known = requestIDs.find(request);
      if (known != requestIDs.end()) return known->second;

      // We discovered a new request; record it's ID and add it to the list of
      // all requests.
      allRequests.push_back(request);
      requestIDs[request] = nextID;
      return nextID++;
   };

   auto getNodeName = [&](const AnyRequest &request) {
      std::string result;
      {
         llvm::raw_string_ostream out(result);
         out << "request_" << getRequestID(request);
      }
      return result;
   };

   // Emit the graph header.
   out << "digraph Dependencies {\n";

   // Emit the edges.
   llvm::DenseMap<AnyRequest, unsigned> inDegree;
   for (const auto &source : allRequests) {
      auto known = dependencies.find(source);
      assert(known != dependencies.end());
      for (const auto &target : known->second) {
         out << "  " << getNodeName(source) << " -> " << getNodeName(target)
             << ";\n";
         ++inDegree[target];
      }
   }

   out << "\n";

   static const char *colorNames[] = {
      "aquamarine",
      "blueviolet",
      "brown",
      "burlywood",
      "cadetblue",
      "chartreuse",
      "chocolate",
      "coral",
      "cornflowerblue",
      "crimson"
   };
   const unsigned numColorNames = sizeof(colorNames) / sizeof(const char *);

   llvm::DenseMap<unsigned, unsigned> knownBuffers;
   auto getColor = [&](const AnyRequest &request) -> Optional<const char *> {
      SourceLoc loc = request.getNearestLoc();
      if (loc.isInvalid())
         return None;

      unsigned bufferID = diags.SourceMgr.findBufferContainingLoc(loc);
      auto knownId = knownBuffers.find(bufferID);
      if (knownId == knownBuffers.end()) {
         knownId = knownBuffers.insert({bufferID, knownBuffers.size()}).first;
      }
      return colorNames[knownId->second % numColorNames];
   };

   // Emit the nodes.
   for (unsigned i : indices(allRequests)) {
      const auto &request = allRequests[i];
      out << "  " << getNodeName(request);
      out << " [label=\"";
      printEscapedString(request.getAsString(), out);

      auto cachedValue = cache.find(request);
      if (cachedValue != cache.end()) {
         out << " -> ";
         printEscapedString(cachedValue->second.getAsString(), out);
      }
      out << "\"";

      if (auto color = getColor(request)) {
         out << ", fillcolor=\"" << *color << "\"";
      }

      out << "];\n";
   }

   // Emit "fake" nodes for each of the source buffers we encountered, so
   // we know which file we're working from.
   // FIXME: This approximates a "top level" request for, e.g., type checking
   // an entire source file.
   std::vector<unsigned> sourceBufferIDs;
   for (const auto &element : knownBuffers) {
      sourceBufferIDs.push_back(element.first);
   }
   std::sort(sourceBufferIDs.begin(), sourceBufferIDs.end());
   for (unsigned bufferID : sourceBufferIDs) {
      out << "  buffer_" << bufferID << "[label=\"";
      printEscapedString(diags.SourceMgr.getIdentifierForBuffer(bufferID), out);
      out << "\"";

      out << ", shape=\"box\"";
      out << ", fillcolor=\""
          << colorNames[knownBuffers[bufferID] % numColorNames] << "\"";
      out << "];\n";
   }

   // Emit "false" dependencies from source buffer IDs to any requests that (1)
   // have no other incomining edges and (2) can be associated with a source
   // buffer.
   for (const auto &request : allRequests) {
      if (inDegree[request] > 0)
         continue;

      SourceLoc loc = request.getNearestLoc();
      if (loc.isInvalid())
         continue;

      unsigned bufferID = diags.SourceMgr.findBufferContainingLoc(loc);
      out << "  buffer_" << bufferID << " -> " << getNodeName(request) << ";\n";
   }

   // Done!
   out << "}\n";
}

void Evaluator::dumpDependenciesGraphviz() const {
   printDependenciesGraphviz(llvm::dbgs());
}

} // polar